# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.


from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import torch.nn.functional as F

import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention
from transformers.utils import ModelOutput

from flash_attn import flash_attn_varlen_func
from modeling.qwen2.modeling_qwen2 import (
    Qwen2Attention, 
    Qwen2MLP, 
    Qwen2MLP_Sparse, 
    Qwen2MoeSparseMoeBlock, 
    Qwen2PreTrainedModel, 
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding,
    apply_rotary_pos_emb,
)

from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
import copy

import copy, gc
from contextlib import nullcontext

torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
flex_attention = torch.compile(flex_attention)


def compute_cos_mag(log_path, layer_idx, residual, packed_query_sequence, layer_type="Attn"): 

    # flatten 到 [N, D] 形式
    residual_flat = residual.view(-1, residual.size(-1))
    query_flat = packed_query_sequence.view(-1, packed_query_sequence.size(-1))

    # Cosine similarity: 每个位置
    cos_sim = F.cosine_similarity(residual_flat, query_flat, dim=-1)  # shape: [N]

    # Magnitude ratio: ||residual|| / ||query||
    residual_norm = torch.norm(query_flat - residual_flat, dim=-1)  # shape: [N]
    query_norm = torch.norm(query_flat, dim=-1)
    mag_ratio = (residual_norm) / (query_norm + 1e-6)  # 防止除以 0

    return cos_sim, mag_ratio

    # 统计结果
    # if log_path is not None:
    #     with open(log_path, 'a') as f:
    #         f.write(f"layer_idx: {layer_idx} | {layer_type} | Cosine similarity (mean): {cos_sim.mean().item():.4f} | Magnitude ratio (mean): {mag_ratio.mean().item():.4f}\n")
    # else: 
    #     print(f"layer_idx: {layer_idx} | {layer_type} | Cosine similarity (mean): {cos_sim.mean().item():.4f} | Magnitude ratio (mean): {mag_ratio.mean().item():.4f}\n")


def compute_attention_scores(query, key, attention_mask=None, is_causal=False, layer_idx=None):
    """
    Compute attention weights (scores) only. No value projection or output.
    
    Inputs:
        - query: [Q_len, QH, D]
        - key:   [K_len, KVH, D]
        - num_kv_heads: Number of key/value heads (for GQA/MQA)
    
    Returns:
        - attn_weights: [QH, Q_len, K_len]
    """
    Q_len, QH, D = query.shape
    K_len, KVH, _ = key.shape

    if QH != KVH:
        if QH % KVH != 0:
            raise ValueError(f"Cannot map {KVH} KV heads to {QH} query heads")
        factor = QH // KVH
        key = key.repeat_interleave(factor, dim=1)  # [K_len, QH, D]

    # [Q_len, QH, D] x [QH, D, K_len] → [QH, Q_len, K_len]
    # We'll permute for matmul
    query_t = query.permute(1, 0, 2)  # [QH, Q_len, D]
    key_t = key.permute(1, 2, 0)      # [QH, D, K_len]

    scores = torch.matmul(query_t, key_t) / (D ** 0.5)  # [QH, Q_len, K_len]

    if attention_mask is not None:
        scores += attention_mask  # shape should broadcast to [QH, Q_len, K_len]

    if is_causal:
        causal_mask = torch.tril(torch.ones(Q_len, K_len, device=query.device)).unsqueeze(0)  # [1, Q_len, K_len]
        scores = scores.masked_fill(causal_mask == 0, float("-inf"))

    attn_weights = F.softmax(scores, dim=-1)  # [QH, Q_len, K_len]
    return attn_weights


class Qwen2Config(_Qwen2Config):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        # gen_text=False, 
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
            **kwargs,
        )
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und


class NaiveCache:
    def __init__(self, num_layers):
        self.key_cache = {k: None for k in range(num_layers)}
        self.value_cache = {k: None for k in range(num_layers)}

    @property
    def num_layers(self):
        return len(self.key_cache)

    @property
    def seq_lens(self):
        if self.key_cache[0] is not None:
            return self.key_cache[0].shape[0]
        else:
            return 0


@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    packed_query_sequence: torch.FloatTensor = None
    past_key_values: Optional[NaiveCache] = None
    router_logits: torch.FloatTensor = None


def pad_sequence(tensor, pad_size):
    H, L, D = tensor.shape
    pad_tensor = tensor.new_zeros((H, pad_size, D))
    return torch.cat([tensor, pad_tensor], dim=1)


class PackedAttention(Qwen2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        
        
    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask: List[torch.Tensor],
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, List):
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            upacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0), 
                        key_states.to(torch.bfloat16).unsqueeze(0), 
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                upacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(upacked_attn_output, dim=1)
        else:
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0), 
                packed_key_states.unsqueeze(0), 
                packed_value_states.unsqueeze(0), 
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        return packed_attn_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
            # return_attn_probs=True
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values


class PackedAttentionMoT(Qwen2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None, sparse_mode="dense"):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.q_norm_moe_gen = nn.Identity()
            self.k_norm_moe_gen = nn.Identity()

        self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.mode = "und"
        self.packed_text_indexes = None
        self.packed_vae_token_indexes = None

        self.sparse_mode = sparse_mode

        if self.sparse_mode == "prune":
            self.register_buffer("act_sum", torch.zeros(self.num_heads))     # ∑|h|
            self.register_buffer("act_cnt", torch.tensor(0, dtype=torch.long))       # batch 计数

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ):
        packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
        packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
        packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))

        packed_sequence_und = packed_sequence[packed_und_token_indexes]
        packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

        packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
        packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)

        packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
        packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)

        packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
        packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)

        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        if self.config.freeze_und:
            packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()

        packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
        packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

        packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
        packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])

        packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
        packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
            packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, List):
            packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            upacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0), 
                        key_states.to(torch.bfloat16).unsqueeze(0), 
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                upacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(upacked_attn_output, dim=1)
        else:
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
            packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
                packed_key_states_.unsqueeze(0), 
                packed_value_states.unsqueeze(0), 
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
        packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
        packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
        packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])

        return packed_attn_output_

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        if mode == 'und':
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == 'gen':
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
            packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
            packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)

            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)

            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:

            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])

            merged_key_states[packed_query_indexes] = packed_key_states
            try:
                merged_key_states[packed_key_value_indexes] = past_key_states
            except RuntimeError: ##### uneven KV-Cache when doing image editing
                pass
            merged_value_states[packed_query_indexes] = packed_value_states
            try:
                merged_value_states[packed_key_value_indexes] = past_value_states
            except RuntimeError: ##### uneven KV-Cache when doing image editing
                pass
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func( # squence length, head, head_dim
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )

        # und head i, gen head i

        # print(f"packed_attn_output: {packed_attn_output.size()}") 

        if self.sparse_mode == "prune":
            if mode == "und":
                # print(f"self.o_proj.weight: {self.o_proj.weight.size()}")
                W_und = self.o_proj.weight.view(self.hidden_size, self.num_heads, self.head_dim).transpose(0, 1)    # (H, Dh, Dv)
                head_out = torch.einsum("shv,hdv->hsd", packed_attn_output, W_und)
            else: 
                W_und = self.o_proj.weight.view(self.hidden_size, self.num_heads, self.head_dim).transpose(0, 1)
                W_gen = self.o_proj_moe_gen.weight.view(self.hidden_size, self.num_heads, self.head_dim).transpose(0, 1)

                head_out_und = torch.einsum("shv,hdv->hsd", packed_attn_output[packed_text_indexes], W_und)
                head_out_gen = torch.einsum("shv,hdv->hsd", packed_attn_output[packed_vae_token_indexes], W_gen)
                head_out = torch.cat((head_out_und, head_out_gen), dim=1)

            # print(f"head_out: {head_out.size()}")   
            act_sum = head_out.abs().mean(dim=(-2, -1))          # (H,)
            # print(f"act_sum: {act_sum.size()}")  
            # print(f"act_sum: {act_sum}")    
            #### # TODO  (H, Dh, Dv) matmul (B, H, S, Dv) --> (B, H, S, D_h)
            # print(f"self.act_sum: {self.act_sum.size()}")  
            self.act_sum += act_sum
            self.act_cnt += 1


        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        if mode == 'und':
            packed_attn_output = self.o_proj(packed_attn_output)
        elif mode == 'gen':
            packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])

        # print(f"packed_attn_output: {packed_attn_output}")
        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values


class Qwen2MoTDecoderLayer(nn.Module):
    def __init__(
        self, 
        config, 
        layer_idx: Optional[int] = None, 
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.freeze_und = config.freeze_und

        self.self_attn = attn_module(config, layer_idx)

        self.transformed_und = config.transformed_und if hasattr(config, "transformed_und") else False  
        self.transformed_gen = config.transformed_gen if hasattr(config, "transformed_gen") else False  

        self.num_experts = config.num_experts if hasattr(config, "num_experts") else 8
        self.num_shared_experts = config.num_shared_experts if hasattr(config, "num_shared_experts") else 1
        self.top_k = config.top_k if hasattr(config, "top_k") else 2
        # self.norm_topk_prob = config.norm_topk_prob if hasattr(config, "norm_topk_prob") else True
        # self.stegate = config.stegate if hasattr(config, "stegate") else True

        config.shared_expert_intermediate_size = config.intermediate_size // (self.num_experts + self.num_shared_experts) * self.num_shared_experts if self.num_shared_experts is not None else None       
        remaining = config.intermediate_size - config.shared_expert_intermediate_size if config.intermediate_size is not None and config.shared_expert_intermediate_size is not None else None
        moe_inter = remaining // self.num_experts if self.num_experts is not None else None
        config.moe_intermediate_size = moe_inter

        # print(f"self.layer_idx: {self.layer_idx}")
        # print(f"self.transformed_und: {self.transformed_und}")

        if isinstance(self.transformed_und, bool):
            if not self.transformed_und:
                self.mlp = Qwen2MLP_Sparse(config, layer_idx, mode="und")
            else:
                self.mlp = Qwen2MoeSparseMoeBlock(config, mode="und")
        elif isinstance(self.transformed_und, list):
            if self.layer_idx not in self.transformed_und:        
                self.mlp = Qwen2MLP_Sparse(config, layer_idx, mode="und")
            else: 
                self.mlp = Qwen2MoeSparseMoeBlock(config, mode="und")
        else: 
            raise NotImplementedError("transformed_und must be a bool or a list of int.")

        if isinstance(self.transformed_gen, bool):
            if not self.transformed_gen:
                self.mlp_moe_gen = Qwen2MLP_Sparse(config, layer_idx, mode="gen")
            else:
                self.mlp_moe_gen = Qwen2MoeSparseMoeBlock(config, mode="gen")
        elif isinstance(self.transformed_gen, list):
            if self.layer_idx not in self.transformed_gen:
                self.mlp_moe_gen = Qwen2MLP_Sparse(config, layer_idx, mode="gen")
            else: 
                self.mlp_moe_gen = Qwen2MoeSparseMoeBlock(config, mode="gen")
        else:
            raise NotImplementedError("transformed_gen must be a bool or a list of int.")

        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.record = False
        self.skip_type = None
        self.skip_mode = None
        self.input = None
        self.residual = None
        self.output = None

    def convert_dense_to_sparse_moe_dual(
            self, 
            mode, 
            importance_scores: Optional[torch.Tensor] = None,    # 1D, 长度=intermediate_size
            shared_ratio: float = 0.05,          # 前多少比例进入共享专家
            sorted_idx: Optional[torch.Tensor] = None,    # 已经按重要性排好序的索引
            shared_idx: Optional[torch.Tensor] = None,
        ):
        """
        Convert a dense `Qwen2MLP` into a `Qwen2MoeSparseMoeBlock`.
        """
        if mode == "und":
            self.mlp = self.dense_to_moe_with_explicit_shared(
                                                                self.mlp, 
                                                                mode, 
                                                                importance_scores=importance_scores, 
                                                                shared_ratio=shared_ratio, 
                                                                sorted_idx=sorted_idx, 
                                                                shared_idx=shared_idx, 
                                                                )
        elif mode == "gen":
            self.mlp_moe_gen = self.dense_to_moe_with_explicit_shared(
                                                                self.mlp_moe_gen, 
                                                                mode, 
                                                                importance_scores=importance_scores, 
                                                                shared_ratio=shared_ratio, 
                                                                sorted_idx=sorted_idx, 
                                                                shared_idx=shared_idx, 
                                                                )

    def dense_to_moe_with_explicit_shared(
        self, 
        dense_mlp,
        mode, 
        importance_scores: Optional[torch.Tensor] = None,    # 1D, 长度=intermediate_size
        shared_ratio: float = 0.05,          # 前多少比例进入共享专家
        sorted_idx: Optional[torch.Tensor] = None,    # 已经按重要性排好序的索引
        shared_idx: Optional[torch.Tensor] = None,
    ) -> Qwen2MoeSparseMoeBlock:

        num_experts = self.num_experts
        top_k = self.top_k
        device = next(dense_mlp.parameters()).device
        dtype  = next(dense_mlp.parameters()).dtype                                      # target

        hidden = dense_mlp.hidden_size
        dense_inter = dense_mlp.intermediate_size

        # ─────────── 0. 根据 importance 排序 ───────────
        if sorted_idx is None:
            if importance_scores is None:
                raise ValueError("importance_scores must be a tensor.")
            importance_scores = importance_scores.flatten()
            sorted_idx = torch.argsort(importance_scores, descending=True)
        else:
            raise NotImplementedError("sorted_idx or importance_scores must be a tensor.")

        # importance_scores = torch.rand(importance_scores.shape, device=importance_scores.device)


        # if shared_idx is None:
        shared_num = int(dense_inter * shared_ratio)
        shared_idx = sorted_idx[:shared_num]
        # print(f"sorted_idx: {sorted_idx}")
        # print(f"shared_idx: {shared_idx}")

        # ────────────────────────────────────────────
        remaining_idx = sorted_idx[shared_num:]
        num_remaining = remaining_idx.numel()

        assert len(remaining_idx) % num_experts == 0
        # ─────────── 1. 把剩余 idx 均匀分到 experts ───────────
        # expert_indices = [remaining_idx[i::num_experts] for i in range(num_experts)]

        # ─────────── 2. 把剩余 idx 蛇形分到 experts ───────────
        chunks = remaining_idx.view(-1, num_experts)  
        chunks[1::2] = torch.flip(chunks[1::2], dims=[1])
        expert_indices = [chunks[:, i].contiguous().view(-1) for i in range(num_experts)]

        # ─────────── 3. 把剩余 idx 连续分块给 experts ───────────
        # assert num_remaining % num_experts == 0, \
        # f"Cannot evenly split {num_remaining} parameters into {num_experts} experts"
        # expert_indices = torch.chunk(remaining_idx, num_experts)  # 每块大小一样

        # ─────────── 3. 按位置顺序划分给 experts ───────────
        # all_idx = torch.arange(dense_inter, dtype=torch.long)
        # remaining_idx_positional = torch.tensor(
        # [i for i in all_idx if i not in shared_idx],
        #     dtype=torch.long
        # )
        # # Positional split (完全按原始位置顺序，忽略重要性)
        # assert len(remaining_idx_positional) % num_experts == 0, \
        #     f"Cannot evenly positional split {len(remaining_idx_positional)} into {num_experts} experts"
        # chunk_size = len(remaining_idx_positional) // num_experts
        # expert_indices = [
        #     remaining_idx_positional[i*chunk_size : (i+1)*chunk_size]
        #     for i in range(num_experts)
        # ]

        # ─────────── 2. 构造 config ───────────
        cfg = copy.deepcopy(self.config)
        cfg.num_experts                     = num_experts
        cfg.num_experts_per_tok             = top_k
        cfg.shared_expert_intermediate_size = shared_idx.numel()
        cfg.moe_intermediate_size           = num_remaining // num_experts

        moe_block = Qwen2MoeSparseMoeBlock(cfg, mode).to(device=device, dtype=dtype)

        # ─────────── 3. 拷贝权重 ───────────
        dense_cpu = dense_mlp.to("cpu", dtype=dtype).eval()
        with torch.no_grad():
            # 共享专家
            se = moe_block.shared_expert
            if se is not None:
                se.gate_proj.weight.copy_(dense_cpu.gate_proj.weight.index_select(0, shared_idx))
                se.up_proj.weight  .copy_(dense_cpu.up_proj.weight .index_select(0, shared_idx))
                se.down_proj.weight.copy_(dense_cpu.down_proj.weight.index_select(1, shared_idx))

            # 普通 experts
            for e, idx in enumerate(expert_indices):
                if moe_block.experts is None:
                    break
                exp = moe_block.experts[e]
                exp.gate_proj.weight.copy_(dense_cpu.gate_proj.weight.index_select(0, idx))
                exp.up_proj.weight  .copy_(dense_cpu.up_proj.weight .index_select(0, idx))
                exp.down_proj.weight.copy_(dense_cpu.down_proj.weight.index_select(1, idx))

            # 初始化 gate
            if hasattr(moe_block, "gate"):
                nn.init.zeros_(moe_block.gate.weight)

            if hasattr(moe_block, "gate_img"):
                nn.init.zeros_(moe_block.gate_img.weight)

            # if hasattr(moe_block, "gating_weight"):
            #     nn.init.zeros_(moe_block.gating_weight)
            #     nn.init.normal_(moe_block.e_score_correction_bias, mean=0.0, std=0.02)

            if hasattr(moe_block, "gate_shared"):
                nn.init.zeros_(moe_block.gate_shared.weight)
                    
        del dense_cpu, dense_mlp
        gc.collect(); torch.cuda.empty_cache()

        return moe_block

    def convert_dense_to_sparse_moe_dual_old(
            self, 
            mode, 
            shared_idx, 
        ):
        """
        Convert a dense `Qwen2MLP` into a `Qwen2MoeSparseMoeBlock`.
        """
        if mode == "und":
            self.mlp = self.dense_to_moe_with_explicit_shared_old(self.mlp, shared_idx, mode)
        elif mode == "gen":
            self.mlp_moe_gen = self.dense_to_moe_with_explicit_shared_old(self.mlp_moe_gen, shared_idx, mode)

    def dense_to_moe_with_explicit_shared_old(
        self, 
        dense_mlp,
        shared_idx,                               # list[int] | 1-D LongTensor
        mode="und", 
    ) -> Qwen2MoeSparseMoeBlock:
        """
        把 dense `Qwen2MLP` 转成 `Qwen2MoeSparseMoeBlock`，其中
        - `shared_idx` 指定的重要行/列当作 **共享专家**；
        - 其余行/列平均分给 `num_experts` 个 routed experts。
        """

        num_experts = self.num_experts
        num_shared_experts = self.num_shared_experts
        
        top_k = self.top_k
        # norm_topk_prob = self.norm_topk_prob
        # stegate = self.stegate
        # if device is None:
        device = next(dense_mlp.parameters()).device
        device = torch.device(device)
        dtype  = next(dense_mlp.parameters()).dtype                                      # target

        hidden = dense_mlp.hidden_size
        dense_inter = dense_mlp.intermediate_size

        # ─────────── 0. 预处理索引 ────────────────────────────────
        dense_inter = dense_mlp.intermediate_size
        shared_idx = torch.as_tensor(shared_idx, dtype=torch.long).unique(sorted=True)
        # assert shared_idx.min() >= 0 and shared_idx.max() < dense_inter, "index out of range"

        remaining_idx = torch.tensor(
            [i for i in range(dense_inter) if i not in shared_idx], dtype=torch.long
        )
        # 均匀拆分剩余索引
        # routed_chunks = torch.chunk(remaining_idx, num_experts)
        # expert_indices = [chunk.to(torch.long) for chunk in routed_chunks]


        # ① round-robin 先分片
        expert_indices = [remaining_idx[i::num_experts]  # 步长 = num_experts
                        for i in range(num_experts)]

        # ② 计算最短片段长度
        min_len = min(idx.numel() for idx in expert_indices)

        # ③ 截断到等长
        expert_indices = [idx[:min_len] for idx in expert_indices]

        mof_sizes = [len(c) for c in expert_indices]
        max_mof = max(mof_sizes) if mof_sizes else 1

        # ─────────── 1. 构造新 config ─────────────────────────────
        cfg = copy.deepcopy(self.config)
        cfg.num_experts                     = num_experts
        cfg.top_k                           = top_k
        # cfg.norm_topk_prob                  = norm_topk_prob
        # cfg.stegate                         = stegate
        cfg.shared_expert_intermediate_size = shared_idx.numel()
        cfg.moe_intermediate_size           = max_mof         # 上限，宽度不齐也可运行

        # ─────────── 2. 创建 MoE 块 (bf16) ───────────────────────
        moe_block = Qwen2MoeSparseMoeBlock(cfg, mode).to(device=device, dtype=dtype)

        # ─────────── 3. 把 dense 权重搬到 CPU+bf16 后拷贝 ───────
        dense_cpu = dense_mlp.to("cpu", dtype=dtype).eval()
        with torch.no_grad():
            # —— 3.1 共享专家 —— #
            if len(shared_idx) > 0:
                se = moe_block.shared_expert
                se.gate_proj.weight.copy_(dense_cpu.gate_proj.weight.index_select(0, shared_idx))
                se.up_proj.weight  .copy_(dense_cpu.up_proj.weight .index_select(0, shared_idx))
                se.down_proj.weight.copy_(dense_cpu.down_proj.weight.index_select(1, shared_idx))

            # —— 3.2 routed experts —— #
            if top_k > 0:
                for e, idx in enumerate(expert_indices):
                    exp = moe_block.experts[e]
                    if idx.numel():                                # 防止空 slice
                        exp.gate_proj.weight.copy_(dense_cpu.gate_proj.weight.index_select(0, idx))
                        exp.up_proj.weight  .copy_(dense_cpu.up_proj.weight .index_select(0, idx))
                        exp.down_proj.weight.copy_(dense_cpu.down_proj.weight.index_select(1, idx))

                # 初始化 gate
                if hasattr(moe_block, "gate"):
                    nn.init.zeros_(moe_block.gate.weight)

                if hasattr(moe_block, "gate_img"):
                    nn.init.zeros_(moe_block.gate_img.weight)

                # if hasattr(moe_block, "gating_weight"):
                #     nn.init.zeros_(moe_block.gating_weight)
                #     nn.init.normal_(moe_block.e_score_correction_bias, mean=0.0, std=0.02)

                if hasattr(moe_block, "gate_shared"):
                    nn.init.zeros_(moe_block.gate_shared.weight)
                
        # ─────────── 4. 清理显存 ─────────────────────────────────
        del dense_cpu, dense_mlp
        gc.collect(); torch.cuda.empty_cache()

        return moe_block

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
        packed_text_indexes: Optional[torch.LongTensor] = None,
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
        packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])

        # Self Attention
        packed_sequence_ = self.self_attn(
            packed_sequence=packed_sequence_,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence = residual + packed_sequence_

        # Fully Connected
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        ####

        packed_sequence_[packed_und_token_indexes], routing_logits_und = self.mlp(
            self.post_attention_layernorm(packed_sequence[packed_und_token_indexes]), 
            packed_text_indexes=packed_text_indexes,
            packed_vit_token_indexes=packed_vit_token_indexes,
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
    
        packed_sequence_[packed_gen_token_indexes], routing_logits_gen = self.mlp_moe_gen(
            self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
        )
        packed_sequence = residual + packed_sequence_
        if routing_logits_und is not None and routing_logits_gen is not None:
            routing_logits = torch.cat((routing_logits_und, routing_logits_gen), dim=0)

        return packed_sequence, routing_logits

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes: Optional[torch.LongTensor] = None,
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,
    ) -> BaseNavitOutputWithPast:

        self.packed_text_indexes = packed_text_indexes
        self.packed_vae_token_indexes = packed_vae_token_indexes

        residual = packed_query_sequence
        # print(f"record: {self.record} | mode: {self.skip_mode} | residual: {residual}" )
        if self.record and mode == self.skip_mode: ## bsz = 1
            self.input = torch.cat((self.input, residual), dim=0) if self.input is not None else residual

        self.self_attn.packed_text_indexes = packed_text_indexes
        self.self_attn.packed_vae_token_indexes = packed_vae_token_indexes
        
        self.input_layernorm.mode = mode
        self.input_layernorm_moe_gen.mode = mode

        self.post_attention_layernorm.mode = mode
        self.post_attention_layernorm_moe_gen.mode = mode

        self.self_attn.mode = mode

        # print(f"{self.layer_idx} mode: {mode} | skip_mode: {self.skip_mode}")

        if self.skip_type in ["block", "attn"] and self.skip_mode == mode:
            pass

        else: 
            if mode == "und":
                packed_query_sequence = self.input_layernorm(packed_query_sequence)
            elif mode == "gen":

                packed_query_sequence_ = torch.zeros_like(packed_query_sequence)

                packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
                packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
                packed_query_sequence = packed_query_sequence_

            # Self Attention
            packed_query_sequence, past_key_values = self.self_attn(
                packed_query_sequence=packed_query_sequence,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                mode=mode,
                packed_vae_token_indexes=packed_vae_token_indexes,
                packed_text_indexes=packed_text_indexes,
            )

            packed_query_sequence = residual + packed_query_sequence            


        # Fully Connected
        residual = packed_query_sequence
        if self.record and mode == self.skip_mode: ## bsz = 1
            self.residual = torch.cat((self.residual, residual), dim=0) if self.residual is not None else residual

        routing_logits = None
        if mode == "und":
            packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
            packed_query_sequence, routing_logits_und = self.mlp(packed_query_sequence, 
                                                                packed_text_indexes=packed_text_indexes,
                                                                packed_vit_token_indexes=packed_vit_token_indexes,
                                                                )
            routing_logits = routing_logits_und
        elif mode == "gen":
            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
            packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
            packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)

            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)

            packed_query_sequence_[packed_text_indexes], routing_logits_und = self.mlp(packed_text_query_sequence)
            packed_query_sequence_[packed_vae_token_indexes], routing_logits_gen = self.mlp_moe_gen(packed_vae_query_sequence)
            packed_query_sequence = packed_query_sequence_
            if routing_logits_und is not None and routing_logits_gen is not None:
                # print(f"routing_logits_und: {routing_logits_und.shape}")
                routing_logits = torch.cat((routing_logits_und, routing_logits_gen), dim=0)
                
        if self.record and mode == self.skip_mode: ## bsz = 1
            self.output = torch.cat((self.output, packed_query_sequence), dim=0) if self.output is not None else packed_query_sequence

        if self.skip_type in ["block", "mlp"] and mode == self.skip_mode:
            packed_query_sequence = residual
        else: 
            packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values, routing_logits


# class Qwen2MoEDecoderLayer(nn.Module):
#     def __init__(self, config, layer_idx: Optional[int] = None):
#         super().__init__()
#         self.layer_idx = layer_idx
#         self.hidden_size = config.hidden_size

#         self.self_attn = PackedAttention(config, layer_idx)

#         self.mlp = Qwen2MLP_Sparse(config, layer_idx, mode="und")
#         self.mlp_moe_gen = Qwen2MLP_Sparse(config, layer_idx, mode="gen")

#         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
#         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

#     def forward(self, *args, **kwargs):
#         if self.training:
#             return self.forward_train(*args, **kwargs)
#         else:
#             # skipped_layers = list(range(25, 27))
#             # if self.layer_idx not in skipped_layers:
#             return self.forward_inference(*args, **kwargs)

#     def forward_train(
#         self,
#         packed_sequence: torch.Tensor,
#         sample_lens: List[int],
#         attention_mask,
#         packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
#         packed_und_token_indexes: torch.LongTensor,
#         packed_gen_token_indexes: torch.LongTensor,
#     ) -> torch.Tensor:

#         residual = packed_sequence
#         packed_sequence = self.input_layernorm(packed_sequence)

#         # Self Attention
#         packed_sequence = self.self_attn(
#             packed_sequence=packed_sequence,
#             sample_lens=sample_lens,
#             attention_mask=attention_mask,
#             packed_position_embeddings=packed_position_embeddings,
#         )
#         packed_sequence = residual + packed_sequence

#         # Fully Connected
#         residual = packed_sequence
#         packed_sequence = self.post_attention_layernorm(packed_sequence)

#         packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
#         packed_sequence_und, routing_logits_und = self.mlp(packed_sequence[packed_und_token_indexes])
#         packed_sequence_gen, routing_logits_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
#         packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
#         packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen

#         packed_sequence = residual + packed_sequence_new

#         if routing_logits_und is not None and routing_logits_gen is not None: 
#             routing_logits = routing_logits_und + routing_logits_gen

#         return packed_sequence

#     def forward_inference(
#         self,
#         packed_query_sequence: torch.Tensor,
#         query_lens: torch.Tensor,
#         packed_query_position_embeddings: torch.Tensor,
#         packed_query_indexes: torch.Tensor,
#         past_key_values: Optional[NaiveCache] = None,
#         key_values_lens: Optional[torch.Tensor] = None,
#         packed_key_value_indexes: Optional[torch.Tensor] = None,
#         update_past_key_values=True,
#         is_causal=True,
#         mode="und",
#         packed_vae_token_indexes=None,
#         packed_text_indexes=None,
#         gen_text=False, 
#     ) -> BaseNavitOutputWithPast:

#         residual = packed_query_sequence
#         packed_query_sequence = self.input_layernorm(packed_query_sequence)

#         # Self Attention
#         packed_query_sequence, past_key_values = self.self_attn(
#             packed_query_sequence=packed_query_sequence,
#             query_lens=query_lens,
#             packed_query_position_embeddings=packed_query_position_embeddings,
#             packed_query_indexes=packed_query_indexes,
#             past_key_values=past_key_values,
#             key_values_lens=key_values_lens,
#             packed_key_value_indexes=packed_key_value_indexes,
#             update_past_key_values=update_past_key_values,
#             is_causal=is_causal,
#         )
#         packed_query_sequence = residual + packed_query_sequence

#         # Fully Connected
#         residual = packed_query_sequence
#         packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
#         if mode == "und":
#             packed_query_sequence = self.mlp(packed_query_sequence, 
#                                             #  gen_text=gen_text
#                                              )
#         elif mode == "gen":
#             packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
#             packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
#             packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
#             packed_query_sequence = packed_query_sequence_
#         packed_query_sequence = residual + packed_query_sequence

#         return packed_query_sequence, past_key_values


Decoder_layer_dict = {
    # "Qwen2DecoderLayer": Qwen2DecoderLayer,
    # "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
    "Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
}


class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = 'Mo' in config.layer_module

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        layer_module = Decoder_layer_dict[config.layer_module]
        self.layers = nn.ModuleList(
            [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        self.packed_text_indexes = None
        self.packed_vae_token_indexes = None
    
        # Initialize weights and apply final processing
        # self.post_init()

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
        packed_text_indexes: Optional[torch.LongTensor] = None,
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        if self.config.freeze_und:
            packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()

        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
        cos = cos.squeeze(0)
        sin = sin.squeeze(0)
        packed_position_embeddings = (cos, sin)

        extra_inputs = {}
        if self.use_moe:
            assert packed_und_token_indexes is not None
            if packed_gen_token_indexes is None:
                packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
            extra_inputs.update(
                packed_und_token_indexes=packed_und_token_indexes,
                packed_gen_token_indexes=packed_gen_token_indexes,
                packed_text_indexes=packed_text_indexes,
                packed_vit_token_indexes=packed_vit_token_indexes,
            )

        all_router_logits = ()
        for decoder_layer in self.layers:
            packed_sequence, router_logits = decoder_layer(
                packed_sequence=packed_sequence,
                sample_lens=sample_lens,
                attention_mask=attention_mask,
                packed_position_embeddings=packed_position_embeddings,
                **extra_inputs
            )
            if router_logits is not None:
                all_router_logits += (router_logits,)

        if self.use_moe:
            packed_sequence_ = torch.zeros_like(packed_sequence)
            packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
            if self.config.freeze_und:
                packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
            packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
            return packed_sequence_, all_router_logits
        else:
            return self.norm(packed_sequence), all_router_logits

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
        packed_vit_token_indexes=None,
        # gen_text=False, 
    ) -> BaseNavitOutputWithPast:


        # merge_layer_idx = 10
        merge_layer_idx = None
        
        self.packed_text_indexes = packed_text_indexes
        self.packed_vae_token_indexes = packed_vae_token_indexes

        # for i in range(len(past_key_values.key_cache)):
        #     if past_key_values.key_cache[0] is not None:
        #         if past_key_values.key_cache[i] is not None:
        #         # pass
        #             print(f"Key Cache {i}: {past_key_values.key_cache[i].size()}")
        #         else:
        #             print(f"Key Cache {i}: None")
        # if packed_query_position_ids.device != packed_query_sequence.device: 
        #     packed_query_position_ids = packed_query_position_ids.to(packed_query_sequence.device)

        # print(f"packed_query_sequence: {packed_query_sequence.device}")
        # print(f"packed_query_position_ids: {packed_query_position_ids.device}")

        
        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))

        cos = cos.squeeze(0)
        sin = sin.squeeze(0)
        packed_query_position_embeddings = (cos, sin)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs.update(mode=mode)
            # extra_inputs.update(gen_text=gen_text)

            if mode == 'gen':
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                extra_inputs.update(
                    packed_vae_token_indexes=packed_vae_token_indexes,
                    packed_text_indexes=packed_text_indexes,
                )

        for i, decoder_layer in enumerate(self.layers):

            if mode == "gen" and merge_layer_idx is not None and i == merge_layer_idx: 
                merged_indexes = torch.cat([packed_text_indexes, packed_vae_token_indexes], dim=0)
                extra_inputs.update(
                    packed_vae_token_indexes=merged_indexes,
                    packed_text_indexes=torch.empty(0, dtype=packed_text_indexes.dtype, device=packed_text_indexes.device),
                )

            if mode == "und": 
                extra_inputs.update(
                    packed_text_indexes=packed_text_indexes,
                    packed_vit_token_indexes=packed_vit_token_indexes,
                )
            
            # print(extra_inputs.keys())
            # print("model", packed_text_indexes is None, packed_vit_token_indexes is None)


            packed_query_sequence, past_key_values, router_logits = decoder_layer(
                packed_query_sequence=packed_query_sequence,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                **extra_inputs,
            )

        if self.use_moe:
            if mode == "und":
                packed_query_sequence = self.norm(packed_query_sequence)
            elif mode == "gen":
                packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
                packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
                packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
                packed_query_sequence = packed_query_sequence_
        else:
            packed_query_sequence = self.norm(packed_query_sequence)

        return BaseNavitOutputWithPast(
            packed_query_sequence=packed_query_sequence,
            past_key_values=past_key_values,
            router_logits=router_logits, 
        )


class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        # self.post_init()

    def init_moe(self):
        for name, param in self.named_parameters():
            if "moe_gen" in name:
                original_name = name.replace("_moe_gen", "")
                param.data.copy_(self.state_dict()[original_name].data)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
        packed_text_indexes: Optional[torch.LongTensor] = None,         #####
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,    #####
    ) -> torch.Tensor:

        # print("CausalLM", packed_text_indexes is None, packed_vit_token_indexes is None)

        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
            packed_text_indexes=packed_text_indexes,
            packed_vit_token_indexes=packed_vit_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes: Optional[torch.LongTensor] = None,         #####
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,    #####
        # gen_text=False, 
    ) -> BaseNavitOutputWithPast:

        # print("CausalLM", packed_text_indexes is None, packed_vit_token_indexes is None)
        # packed_text_indexes + packed_vit_token_indexes

        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
            packed_vit_token_indexes=packed_vit_token_indexes, 
            # gen_text=gen_text,
        )

        return outputs




